import time
import json
import math
import conf.settings as settings
from multiprocessing import Queue, Process
from threading import Thread
from schedule import Scheduler
from os import getpid, kill
from typing import Dict, List
from clogger import logger
from sysom_utils import SysomFramework
from lib.algorithm.weight_algorithm import WeightsCalculator,\
    TypeMetricWeights, TYPES
from lib.score_result import ScoreResult, ScoreType, TypeResult, LevelResults
from lib.metric_manager import MetricManager, Level
from lib.utils import collect_all_clusters, collect_instances_of_cluster, \
    collect_pods_of_instance
from lib.metric_exception import MetricProcessException


class Analyzer(Process):
    """Daemon process that periodically computes cluster health scores.

    Scores are built bottom-up: metric -> metric type -> pod/node ->
    cluster, using the weighting strategy configured in ``weight_cal``
    ("WeightedSum" / "Equal" / "Worst", or an automatic weighting for any
    other method name).  Results are stored as JSON in SysomFramework
    gcache tables ("pod_metrics", "instance_metrics", "cluster_metrics"),
    and alarms/diagnoses are delivered for metrics whose score drops to or
    below their configured threshold.
    """

    def __init__(
        self,
        clusterhealth_interval: int = 60,
        queue: Queue = None,
        metric_manager: MetricManager = None,
        weight_cal: WeightsCalculator = None,
        parent_pid: int = None
    ) -> None:
        """
        Args:
            clusterhealth_interval: seconds between two scoring rounds.
            queue: queue used to hand diagnose requests to the consumer.
            metric_manager: registry of metrics per level and metric type.
            weight_cal: provides type weights and the weighting method.
            parent_pid: pid of the parent process, used for liveness checks.
        """
        super().__init__(daemon=True)
        self.clusterhealth_interval = clusterhealth_interval
        self.clusterhealth_host_schedule: Scheduler = Scheduler()
        self.metric_manager = metric_manager
        self.weight_cal = weight_cal
        # Start the first round looking one full interval back in time.
        self.last_end_time = time.time() - self.clusterhealth_interval
        # alarm key -> number of consecutive rounds the alarm has fired.
        self.last_alarm_table: Dict[str, int] = {}
        self.queue = queue
        self.parent_pid = parent_pid

    def check_if_parent_is_alive(self):
        """Terminate this process when the parent process no longer exists."""
        try:
            # Signal 0 checks process existence without delivering a signal.
            kill(self.parent_pid, 0)
        except OSError:
            logger.info(f"Analyzer's parent {self.parent_pid} has exited")
            raise SystemExit(0)

    def deliver_one_alarm(self, cluster: str, instance: str, pod: str,
                          type: str, level: Level, score: float,
                          value: float, metric):
        """Deliver an alarm (and trigger a diagnose) for one low metric score.

        Consecutive alarms for the same key are merged: only the first round
        of a low-score streak delivers an alarm, and after the streak exceeds
        ``settings.ALARM_MERGE_NUM`` rounds the counter is reset so the alarm
        is re-sent.  The streak counter is cleared once the score recovers
        above the threshold.
        """
        # Metrics without an alarm configuration never alarm.
        if metric.settings.alarm is None:
            return

        threshold = metric.settings.alarm.threshold
        description = metric.settings.description

        key = f"{cluster}-{instance}-{pod}-{description}"

        # score lower than (or equal to) threshold: deliver alarm
        if score <= threshold:
            if key not in self.last_alarm_table:
                self.last_alarm_table[key] = 0

            continue_alarm = self.last_alarm_table[key]
            # first alarm of the streak: deliver it and raise a diagnose
            if continue_alarm == 0:
                alarm_id = metric.deliver_alarm(value, type)
                metric.deliver_diagnose(alarm_id, level, type, self.queue)

            self.last_alarm_table[key] += 1

            # streak longer than ALARM_MERGE_NUM rounds: reset so the
            # alarm is delivered again next round
            if continue_alarm > settings.ALARM_MERGE_NUM:
                self.last_alarm_table[key] = 0
        else:
            # the continuous alarm streak ended: forget the key
            if key in self.last_alarm_table:
                del self.last_alarm_table[key]

    def _get_metric_score(self, level: Level,
                          labels: Dict[str, str]) -> TypeResult:
        """Collect per-metric scores for one pod/node, grouped by metric type.

        Types whose weight is 0 are skipped.  Exceptions raised by
        ``metric_score`` propagate to the caller, which decides whether the
        whole item is dropped.
        """
        type_res = TypeResult({})
        type_weights_list = self.weight_cal.type_weights[level]
        registed_metric = self.metric_manager.registed_metric[level]

        for type_weights in type_weights_list:
            metric_type = type_weights.type
            if type_weights.weight == 0:
                continue

            pod = labels.get("pod", "")
            instance = labels.get("instance", "")
            cluster = labels.get("cluster", "")
            type_res[metric_type] = []

            for metric in registed_metric[metric_type]:
                # get metric's value and score; errors propagate to caller
                value, score = metric.metric_score(pod, instance,
                                                   cluster,
                                                   self.last_end_time)
                # deliver alarm if the score is below threshold
                self.deliver_one_alarm(cluster, instance, pod, metric_type,
                                       level, score, value, metric)

                # construct metric labels
                metric_labels = labels.copy()
                metric_labels["type"] = metric_type
                metric_labels["description"] = metric.settings.description

                type_res[metric_type].append(
                    ScoreResult(
                        metric_labels, score, value, ScoreType.MetricScore
                    ).to_dict()
                )

        return type_res

    def _cal_one(self, level: Level,
                 labels: Dict[str, str]) -> List[ScoreResult]:
        """Score one pod/node using a static weighting method.

        Returns a flat list of ScoreResult dicts: per-metric scores, one
        MetricTypeScore per metric type, and the overall InstanceScore as
        the LAST element (callers rely on ``result[-1]``).
        """
        final_score = 0
        result = []
        pod = labels.get("pod", "")
        instance = labels.get("instance", "")
        type_weights_list = self.weight_cal.type_weights[level]
        weights_method = self.weight_cal.weights_method[level]

        for type_weights in type_weights_list:
            metric_type = type_weights.type
            type_weight = type_weights.weight
            metrics_score = []
            type_score = 0

            if type_weight == 0:
                continue

            for metric in self.metric_manager.registed_metric[level][metric_type]:
                try:
                    value, score = metric.metric_score(pod, instance,
                                                       labels["cluster"],
                                                       self.last_end_time)
                except MetricProcessException as e:
                    # best effort: a failing metric is skipped, not fatal
                    logger.info(f"Calculate Metric: {metric.settings.description} "
                                f"of Pod: {pod} of Node: {instance} failed {e}")
                    continue

                self.deliver_one_alarm(labels["cluster"], instance, pod,
                                       metric_type, level, score, value, metric)

                metric_labels = labels.copy()
                metric_labels["type"] = metric_type
                metric_labels["description"] = metric.settings.description

                result.append(
                    ScoreResult(metric_labels, score, value,
                                ScoreType.MetricScore).to_dict()
                )

                metrics_score.append(score)
                if weights_method == "WeightedSum":
                    type_score += score * metric.settings.score.weight

            # Every metric of this type may have failed above; fall back to 0
            # (the value the WeightedSum path would yield) instead of raising
            # ZeroDivisionError / ValueError on the empty list.
            if weights_method == "Equal":
                type_score = (sum(metrics_score) / len(metrics_score)
                              if metrics_score else 0)
            elif weights_method == "Worst":
                type_score = min(metrics_score, default=0)

            type_labels = labels.copy()
            type_labels["type"] = metric_type
            result.append(
                ScoreResult(type_labels, type_score, 0,
                            ScoreType.MetricTypeScore).to_dict()
            )
            # final score of the pod/node is the weighted sum of each
            # metric type's score
            final_score += type_score * type_weight

        # final score is stored in the last element of the result list
        final_score = math.floor(final_score)
        result.append(
            ScoreResult(
                labels, final_score, 0, ScoreType.InstanceScore
            ).to_dict()
        )

        return result

    def _cal_one_auto_weights(
        self, level: Level,
        item: LevelResults,
        metric_weights: TypeMetricWeights
    ) -> List[ScoreResult]:
        """Aggregate pre-collected metric scores with auto-derived weights.

        ``item`` holds the raw per-metric scores collected by
        ``_get_metric_score``; ``metric_weights`` holds one weight per metric
        per type.  The item's overall InstanceScore is the LAST element of
        the returned list.
        """
        type_weights = self.weight_cal.type_weights[level]

        final_result: List[ScoreResult] = []
        base_labels = item.labels
        item_score = 0

        for type_weight in type_weights:
            metric_type = type_weight.type
            weight = type_weight.weight
            metric_scores = item.results[metric_type]

            # type score = weighted sum of this type's metric scores
            type_score = sum(
                entry["score"] * w
                for entry, w in zip(metric_scores,
                                    metric_weights[metric_type])
            )

            type_labels = base_labels.copy()
            type_labels["type"] = metric_type
            final_result.extend(metric_scores)
            final_result.append(
                ScoreResult(
                    type_labels, type_score, 0, ScoreType.MetricTypeScore
                ).to_dict()
            )

            item_score += type_score * weight

        item_score = math.floor(item_score)
        final_result.append(
            ScoreResult(
                base_labels, item_score, 0, ScoreType.InstanceScore
            ).to_dict()
        )

        return final_result

    def _cal_pods_auto_weights(
        self, cluster: str, instance: str,
        pod_list, g_cache
    ) -> List[float]:
        """Score all pods of one node with automatic metric weighting.

        Collects raw metric scores for every pod first (weights are derived
        from the whole population), then aggregates each pod and stores its
        result in ``g_cache``.  Returns the list of pod scores; returns []
        when collection failed for any pod.
        """
        level = Level.Pod
        pods_score = []
        pods_result = []

        for pod, podns in pod_list:
            labels = {
                "pod": pod,
                "namespace": podns,
                "instance": instance,
                "cluster": cluster,
            }

            try:
                pods_result.append(
                    LevelResults(
                        labels,
                        self._get_metric_score(level, labels)
                    )
                )
            except Exception as e:
                logger.error(f"Collect metric of Pod: {pod} of "
                             f"Node: {instance} failed: {e}")
                return []

        metric_weights = self.weight_cal.cal_metric_weights(level, pods_result)
        for pod in pods_result:
            pod_result = self._cal_one_auto_weights(level, pod, metric_weights)
            pod_key = pod.labels["pod"] + f"-{pod.labels['namespace']}"
            pods_score.append(pod_result[-1]["score"])
            g_cache.store(pod_key, json.dumps(pod_result))

        return pods_score

    def _cal_one_pod(self, cluster: str, pod: str,
                     pod_ns: str, instance: str, g_cache_pod) -> float:
        """Score a single pod with a static method and cache the result.

        Returns the pod's final score, or None when calculation failed.
        """
        result = []
        final_pod_score = 0

        try:
            labels = {
                "pod": pod,
                "namespace": pod_ns,
                "instance": instance,
                "cluster": cluster,
            }

            result = self._cal_one(Level.Pod, labels)
            # _cal_one stores the overall score in the last element
            final_pod_score = result[-1]["score"]

            pod_key = pod + f"-{pod_ns}"
            g_cache_pod.store(pod_key, json.dumps(result))

        except Exception as e:
            logger.error(f"Calculate score of Pod: {pod} of "
                         f"Node: {instance} failed: {e}")
            return None

        return final_pod_score

    def _cal_pods(self, cluster: str, instance: str) -> List[float]:
        """Score every pod on one node and return the list of pod scores."""
        g_cache_pod = SysomFramework.gcache("pod_metrics")
        pod_method = self.weight_cal.weights_method[Level.Pod]
        pod_results = []

        pod_list = collect_pods_of_instance(instance,
                                            self.metric_manager.metric_reader,
                                            self.clusterhealth_interval)

        if pod_method in ["WeightedSum", "Equal", "Worst"]:
            for pod, pod_ns in pod_list:
                pod_result = self._cal_one_pod(cluster, pod,
                                               pod_ns, instance, g_cache_pod)
                if pod_result is not None:
                    pod_results.append(pod_result)
        else:
            # any other method name selects automatic weighting
            pod_results = self._cal_pods_auto_weights(cluster, instance,
                                                      pod_list, g_cache_pod)

        return pod_results

    def _combine_pods_node(self, pod_res: List[float],
                           node_score: float) -> float:
        """Blend a node's own score with the average of its pods' scores.

        Fixed 70/30 split (node/pods); a node without pods keeps its own
        score unchanged.
        """
        weights = [0.7, 0.3]

        if not pod_res:
            return node_score

        pod_avg = sum(pod_res) / len(pod_res)

        return node_score * weights[0] + pod_avg * weights[1]

    def _cal_nodes_auto_weights(self, cluster: str,
                                instance_list: List[str], g_cache):
        """Score a batch of nodes with automatic metric weighting.

        Collects raw metric scores for all nodes first, derives weights from
        the population, then aggregates each node, blends in its pods' scores
        and stores the result in ``g_cache``.  Aborts the whole batch when
        collection fails for any node.
        """
        level = Level.Node
        instances_result = []

        for instance in instance_list:
            labels = {
                "instance": instance,
                "cluster": cluster,
            }

            try:
                instances_result.append(
                    LevelResults(
                        labels,
                        self._get_metric_score(level, labels)
                    )
                )
            except Exception as e:
                logger.error(f"Collect metric of Node: {instance} failed: {e}")
                return

        metric_weights = self.weight_cal.cal_metric_weights(
            level, instances_result)
        for instance in instances_result:
            pod_res = self._cal_pods(cluster, instance.labels["instance"])
            node_res = self._cal_one_auto_weights(
                level, instance, metric_weights)
            # combine pod score and node score and update the final entry
            node_score = math.floor(
                self._combine_pods_node(pod_res, node_res[-1]["score"])
            )
            node_res[-1]["score"] = node_score

            g_cache.store(instance.labels["instance"],
                          json.dumps(node_res))

    def _cal_one_node(self, cluster: str, instance: str,
                      g_cache_instance) -> float:
        """Score one node (static method), blend in its pods, cache result.

        Returns the node's final score, or None when calculation failed.
        """
        node_score = 0
        node_result = []

        pod_results = self._cal_pods(cluster, instance)

        try:
            labels = {
                "instance": instance,
                "cluster": cluster,
            }

            node_result = self._cal_one(Level.Node, labels)
            node_score = math.floor(
                self._combine_pods_node(pod_results, node_result[-1]["score"])
            )
            node_result[-1]["score"] = node_score

            # TODO: decide how pod health scores should be folded into the
            # node health score
            g_cache_instance.store(instance, json.dumps(node_result))
        except Exception as e:
            logger.error(f"Calculating score of Node: {instance} failed")
            logger.exception(e)
            return None

        return node_score

    def _cal_nodes_task(self, i: int, cluster: str, instance_list: List[str]):
        """Worker task ``i`` (1-based): score its slice of ``instance_list``.

        The list is partitioned into ANALYZER_PROCESS_NUM contiguous,
        non-overlapping slices; the last worker absorbs any rounding
        remainder.
        """
        g_cache_instance = SysomFramework.gcache("instance_metrics")
        node_method = self.weight_cal.weights_method[Level.Node]
        metric_per_processor = len(instance_list) / \
            settings.ANALYZER_PROCESS_NUM

        if i == settings.ANALYZER_PROCESS_NUM:
            assigned_max = len(instance_list)
        else:
            assigned_max = int(metric_per_processor * i)

        # Derive the lower bound from (i - 1) so consecutive slices meet
        # exactly; subtracting int(metric_per_processor) from assigned_max
        # leaves gaps when the division is not exact, skipping instances.
        assigned_min = int(metric_per_processor * (i - 1))
        assigned_node = range(assigned_min, assigned_max)

        if node_method in ["WeightedSum", "Equal", "Worst"]:
            for j in assigned_node:
                instance = instance_list[j]
                self._cal_one_node(cluster, instance,
                                   g_cache_instance)
        else:
            self._cal_nodes_auto_weights(
                cluster,
                instance_list[assigned_min:assigned_max],
                g_cache_instance
            )

    def _cal_one_cluster(self, cluster: str):
        """Score every node of one cluster, then derive the cluster score."""

        def __cal_nodes_multi_thread():
            # Fan node scoring out over ANALYZER_PROCESS_NUM threads.
            threads = []
            for i in range(1, settings.ANALYZER_PROCESS_NUM + 1):
                if i > len(instances_list):
                    logger.warning("process num is set to be"
                                   " larger than instance num!")
                    break

                t = Thread(target=self._cal_nodes_task,
                           args=(i, cluster, instances_list))
                threads.append(t)
                t.start()

            for t in threads:
                t.join()

        def __cal_nodes_normal():
            # Score all nodes sequentially in this thread.
            g_cache_instance = SysomFramework.gcache("instance_metrics")
            node_method = self.weight_cal.weights_method[Level.Node]

            if node_method in ["WeightedSum", "Equal", "Worst"]:
                for instance in instances_list:
                    self._cal_one_node(cluster, instance,
                                       g_cache_instance)
            else:
                self._cal_nodes_auto_weights(cluster,
                                             instances_list, g_cache_instance)

        def __nodes_to_cluster(labels: Dict[str, str]) -> List[ScoreResult]:
            # Aggregate cached node results into cluster-level scores.
            final_score = 0
            cluster_res = []
            type_score = {}
            nodes_score = []
            g_cache_instance = SysomFramework.gcache("instance_metrics")

            for metric_type in TYPES:
                type_score[metric_type] = []

            instances = g_cache_instance.load_all()
            for _, instance_res in instances.items():
                res = json.loads(instance_res)
                for metric in res:
                    if metric["type"] == ScoreType.MetricTypeScore.value:
                        type_score[metric["labels"]["type"]].append(
                            metric["score"])
                    elif metric["type"] == ScoreType.InstanceScore.value:
                        nodes_score.append(metric["score"])

            for metric_type in TYPES:
                type_labels = labels.copy()
                type_labels["type"] = metric_type

                if len(type_score[metric_type]) == 0:
                    logger.warning(f"No Nodes's {metric_type} score")
                    continue
                # cluster type score = avg(nodes' type score)
                avg_score = (sum(type_score[metric_type])
                             / len(type_score[metric_type]))
                cluster_res.append(
                    ScoreResult(
                        type_labels, avg_score, 0,
                        ScoreType.MetricTypeScore
                    ).to_dict()
                )

            try:
                # cluster score = avg(nodes' score)
                final_score = math.floor(sum(nodes_score) / len(nodes_score))
                cluster_res.append(
                    ScoreResult(
                        labels, final_score, 0, ScoreType.InstanceScore
                    ).to_dict()
                )
            except ZeroDivisionError as e:
                logger.info("no nodes in cluster!")
                raise e

            return cluster_res

        cluster_result = []
        g_cache_cluster = SysomFramework.gcache("cluster_metrics")
        instances_list = collect_instances_of_cluster(
            cluster,
            self.metric_manager.metric_reader,
            self.clusterhealth_interval
        )

        labels = {
            "cluster": cluster,
        }

        if settings.ENABLE_MULTI_THREAD is True:
            __cal_nodes_multi_thread()
        else:
            __cal_nodes_normal()

        try:
            cluster_result = __nodes_to_cluster(labels)
            g_cache_cluster.store(cluster, json.dumps(cluster_result))

        except Exception as e:
            logger.error(f"Calculating score of Cluster: {cluster} failed")
            logger.exception(e)

    def _register_task(self):
        """One scheduled scoring round: score every known cluster."""
        cluster_list = collect_all_clusters(self.metric_manager.metric_reader)
        # no cluster label: assume a single cluster named "default"
        if len(cluster_list) == 0 or settings.NO_CLUSTER_LABEL is True:
            cluster_list.append("default")

        start_time = time.time()

        for cluster in cluster_list:
            self._cal_one_cluster(cluster)

        # next round queries metrics starting from this point in time
        self.last_end_time = time.time()
        end_time = time.time()
        logger.info(f"Execution time: {end_time - start_time}")

    def run(self) -> None:
        """Process entry point: run one round, then reschedule periodically."""
        logger.info(f'健康度计算守护进程PID： {getpid()}')

        self._register_task()
        self.clusterhealth_host_schedule.every(self.clusterhealth_interval)\
            .seconds.do(self._register_task)

        while True:
            self.check_if_parent_is_alive()

            if self.is_alive():
                self.clusterhealth_host_schedule.run_pending()
            else:
                break
            # poll at half the interval (at least 1s) so runs stay on time
            time.sleep(max(1, int(self.clusterhealth_interval / 2)))
